{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Maintaining a batch of populations using the functional EvoTorch API\n",
    "\n",
    "## Motivation\n",
    "EvoTorch already implements mechanisms for accelerating the evaluation of solutions in a population using Ray actors and/or PyTorch vectorization, which is essential for obtaining results in a reasonable time.\n",
    "\n",
    "However, one usually also needs to run search algorithms multiple times to evaluate the effect of various hyperparameters (such as learning rate, mutation probability, etc.). This can take a lot of time, and parallelizing it (e.g. on a cluster) requires additional effort. The functional API of EvoTorch attempts to address this need by leveraging [PyTorch's vmap() transform](https://pytorch.org/docs/stable/generated/torch.func.vmap.html#torch.func.vmap).\n",
    "\n",
    "The main idea is that search algorithms in `evotorch.algorithms.functional` are pure functional implementations, and thus can easily be transformed (using `vmap()`) to operate on multiple populations stored as _batches of populations_ (or batches of batches of populations, and so on). With the help of such implementations, one can run multiple searches in parallel starting from different initial populations that cover different regions of the search space. As another example, one can run multiple searches that each have a different initial population as well as a corresponding learning rate.\n",
    "\n",
    "In this notebook, we demonstrate how a batch of populations, each originating from a different starting point, can be maintained so that different regions of the search space can be explored simultaneously. The key concepts covered are:\n",
    "- the ask-and-tell interface provided by search algorithms in the `evotorch.algorithms.functional` namespace.\n",
    "- the `@rowwise` decorator that makes it easier to write evaluation functions that can be automatically transformed using `vmap()`.\n",
    "- running multiple searches using vectorization simply by adding a batch dimension to arguments in the ask-and-tell interface and letting EvoTorch handle the rest."
   ]
  },
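  {
   "cell_type": "markdown",
   "id": "0a",
   "metadata": {},
   "source": [
    "To make the idea of transforming a pure function with `vmap()` more concrete, here is a minimal sketch that uses only PyTorch. The function `gradient_step()` is a hypothetical stand-in written for this illustration (it is not part of EvoTorch): it performs one gradient-descent step on the sum of squares of a single center vector, using a single learning rate. Wrapping it with `torch.func.vmap()` lets the same code process a batch of centers, each paired with its own learning rate, which is the same pattern the functional search algorithms rely on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.func import vmap\n",
    "\n",
    "\n",
    "def gradient_step(center: torch.Tensor, learning_rate: torch.Tensor) -> torch.Tensor:\n",
    "    # Hypothetical example function (not an EvoTorch API).\n",
    "    # One gradient-descent step on f(x) = sum(x ** 2), whose gradient is 2 * x.\n",
    "    # The function is pure: it only reads its arguments and returns a new tensor.\n",
    "    return center - learning_rate * 2.0 * center\n",
    "\n",
    "\n",
    "# Unbatched call: one center vector and one scalar learning rate\n",
    "center = torch.tensor([1.0, -2.0, 3.0])\n",
    "single_result = gradient_step(center, torch.tensor(0.25))\n",
    "\n",
    "# Batched call: 4 copies of the center, each paired with its own learning rate\n",
    "centers = center.repeat(4, 1)\n",
    "learning_rates = torch.tensor([0.0, 0.25, 0.5, 1.0])\n",
    "batched_result = vmap(gradient_step)(centers, learning_rates)\n",
    "\n",
    "print(single_result)\n",
    "print(batched_result.shape)"
   ]
  },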
  {
   "cell_type": "markdown",
   "id": "1",
   "metadata": {},
   "source": [
    "We begin by importing the necessary libraries and defining some useful variables:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from evotorch.algorithms.functional import cem, cem_ask, cem_tell, pgpe, pgpe_ask, pgpe_tell\n",
    "from evotorch.decorators import rowwise\n",
    "\n",
    "from datetime import datetime\n",
    "import torch\n",
    "from math import pi\n",
    "\n",
    "# Use a GPU to achieve speedups from vectorization if possible\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# We will search for 1000-dimensional solution vectors\n",
    "solution_length = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3",
   "metadata": {},
   "source": [
    "## Ask-and-tell\n",
    "Next, we implement a simple optimization loop for the commonly used Rastrigin function using the ask-and-tell interface of the Cross Entropy Method (CEM). An important detail is that we evaluate the full population directly with the evaluation function `rastrigin()`, so it must be implemented to operate on a population represented as a 2-dimensional tensor in which each row is a solution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rastrigin(x: torch.Tensor) -> torch.Tensor:\n",
    "    n = x.shape[-1]\n",
    "    A = 10.0\n",
    "    return A * n + torch.sum((x ** 2) - (A * torch.cos(2 * pi * x)), dim=-1)\n",
    "\n",
    "\n",
    "# Uniformly sample center_init in [-5.12, 5.12], the typical domain of the Rastrigin function\n",
    "center_init = (torch.rand(solution_length) * 2 - 1) * 5.12\n",
    "# Set stdev_max_change to 0.1 for all solution dimensions\n",
    "stdev_max_change = 0.1 * torch.ones(solution_length)\n",
    "\n",
    "cem_state = cem(\n",
    "    # We want to minimize the evaluation results\n",
    "    objective_sense=\"min\",\n",
    "\n",
    "    # `center_init` is the center point of the initial search distribution.\n",
    "    center_init=center_init.to(device),\n",
    "\n",
    "    # The standard deviation of the initial search distribution.\n",
    "    stdev_init=10.0,\n",
    "\n",
    "    # Maximum relative change allowed for the standard deviation, given here\n",
    "    # as a single vector and kept on the same device as `center_init`.\n",
    "    stdev_max_change=stdev_max_change.to(device),\n",
    "\n",
    "    # Solutions belonging to the top half (top 50%) of the population(s)\n",
    "    # will be chosen as parents.\n",
    "    parenthood_ratio=0.5,\n",
    ")\n",
    "\n",
    "# We will run the evolutionary search for this many generations:\n",
    "num_generations = 1500\n",
    "\n",
    "# Interval (in seconds) for printing the status:\n",
    "report_interval = 3\n",
    "start_time = last_report_time = datetime.now()\n",
    "\n",
    "for generation in range(1, 1 + num_generations):\n",
    "    # Get a population from the evolutionary algorithm\n",
    "    population = cem_ask(cem_state, popsize=500)\n",
    "\n",
    "    # Compute the fitnesses\n",
    "    fitnesses = rastrigin(population)\n",
    "\n",
    "    # Inform the evolutionary algorithm of the fitnesses and get its next state\n",
    "    cem_state = cem_tell(cem_state, population, fitnesses)\n",
    "\n",
    "    # If it is time to report, print the status\n",
    "    tnow = datetime.now()\n",
    "    if ((tnow - last_report_time).total_seconds() > report_interval) or (generation == num_generations):\n",
    "        print(\"generation:\", generation, \"mean fitnesses:\", torch.mean(fitnesses, dim=-1))\n",
    "        last_report_time = tnow\n",
    "\n",
    "print(\"time taken: \", (last_report_time - start_time).total_seconds(), \"secs\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5",
   "metadata": {},
   "source": [
    "## The @rowwise decorator\n",
    "\n",
    "Next, we modify the code above so that multiple searches can be executed simultaneously taking advantage of PyTorch's vectorization capabilities. We modify the fitness function as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "@rowwise\n",
    "def rastrigin(x: torch.Tensor) -> torch.Tensor:\n",
    "    [n] = x.shape\n",
    "    A = 10.0\n",
    "    return A * n + torch.sum((x ** 2) - (A * torch.cos(2 * pi * x)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7",
   "metadata": {},
   "source": [
    "Notice how the fitness function above is decorated via `@rowwise`. This decorator tells EvoTorch that the user has defined the function to operate on its argument `x` as a vector (i.e. a 1-dimensional tensor). This makes it conceptually easier to implement the function and helps EvoTorch safely `vmap()` it in order to apply it to populations or a batch of populations as needed. `@rowwise` ensures that:\n",
    "\n",
    "- if the argument `x` is indeed received as a 1-dimensional tensor, the function works exactly as defined;\n",
    "- if the argument `x` is received as a matrix (i.e. as a 2-dimensional tensor), the operations of the function are applied for each row of the matrix;\n",
    "- if the argument `x` is received as a tensor with 3 or more dimensions, the operations of the function are applied for each row of each matrix.\n",
    "\n",
    "Thanks to this, the fitness function `rastrigin` can be used as it is to evaluate a single solution (represented by a 1-dimensional tensor), a single population (represented by a 2-dimensional tensor), or a batch of populations (represented by a tensor with 3 or more dimensions).\n",
    "\n",
    "_Note: We don't *have* to use `@rowwise` to implement our fitness function. Indeed, since our previous definition of `rastrigin()` happens to handle any number of batch dimensions and return a fitness value for each vector, we can use it as-is for running a batch of multiple searches. However, writing the fitness function in such a general way can often be difficult and error-prone, so it is much more convenient to use `@rowwise`._"
   ]
  },
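  {
   "cell_type": "markdown",
   "id": "7a",
   "metadata": {},
   "source": [
    "The three cases listed above can be checked with a short sketch. The name `rastrigin_rowwise` and the small solution length of 8 are choices made only for this illustration, so that the `rastrigin` function defined earlier is left untouched. The cell calls the same decorated function with a single solution, a single population, and a batch of populations, and prints the shape of each result, which should contain one fitness value per solution vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from math import pi\n",
    "from evotorch.decorators import rowwise\n",
    "\n",
    "\n",
    "@rowwise\n",
    "def rastrigin_rowwise(x: torch.Tensor) -> torch.Tensor:\n",
    "    # Written for a single solution vector, like the decorated function above\n",
    "    [n] = x.shape\n",
    "    A = 10.0\n",
    "    return A * n + torch.sum((x ** 2) - (A * torch.cos(2 * pi * x)))\n",
    "\n",
    "\n",
    "# Illustrative inputs: 8-dimensional solutions, 5 solutions per population,\n",
    "# and 4 populations in the batch\n",
    "single_solution = torch.randn(8)\n",
    "single_population = torch.randn(5, 8)\n",
    "batch_of_populations = torch.randn(4, 5, 8)\n",
    "\n",
    "# Each result holds one fitness value per solution vector\n",
    "print(rastrigin_rowwise(single_solution).shape)\n",
    "print(rastrigin_rowwise(single_population).shape)\n",
    "print(rastrigin_rowwise(batch_of_populations).shape)"
   ]
  },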
  {
   "cell_type": "markdown",
   "id": "8",
   "metadata": {},
   "source": [
    "## Batched (vectorized) searches\n",
    "Using the modified rastrigin function above, we are almost ready to run a batch of searches utilizing additional vectorization over the number of searches.\n",
    "We will run 4 searches in a batch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10",
   "metadata": {},
   "source": [
    "For both functional `cem` and functional `pgpe`, the hyperparameter `stdev_max_change` can be given as a scalar (which then will be expanded to a vector), or as a vector (which then will be used as it is), or as a batch of vectors (which will mean that for each batch item `i`, the `i`-th `stdev_max_change` vector will be used).\n",
    "\n",
    "Since we consider a batch of populations in this example, let us make a batch of starting points and `stdev_max_change` vectors, meaning that each population will have its own different starting point and `stdev_max_change` hyperparameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "center_inits = ((torch.rand(batch_size, solution_length) * 2) - 1) * 5.12\n",
    "# Evenly space the stdev_max_change values of the searches between 0.01 and 0.2\n",
    "stdev_max_changes = torch.linspace(0.01, 0.2, batch_size)[:, None].expand(-1, solution_length)\n",
    "print(center_inits.shape, stdev_max_changes.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12",
   "metadata": {},
   "source": [
    "Next, we simply provide these to the CEM state initializer and run CEM using the ask-and-tell interface exactly as before. Internally, EvoTorch will recognize the new batch dimension and appropriately `vmap()` the fitness function for us!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "cem_state = cem(\n",
    "    objective_sense=\"min\",\n",
    "\n",
    "    # The batch of vectors `center_inits` is given as our `center_init`,\n",
    "    # that is, the center point(s) of the initial search distribution(s).\n",
    "    center_init=center_inits.to(device),\n",
    "\n",
    "    # The standard deviation of the initial search distribution(s).\n",
    "    stdev_init=10.0,\n",
    "\n",
    "    # We provide our batch of hyperparameter vectors as `stdev_max_change`.\n",
    "    stdev_max_change=stdev_max_changes.to(device),\n",
    "\n",
    "    parenthood_ratio=0.5,\n",
    ")\n",
    "\n",
    "start_time = last_report_time = datetime.now()\n",
    "\n",
    "for generation in range(1, 1 + num_generations):\n",
    "    # Get a population from the evolutionary algorithm\n",
    "    population = cem_ask(cem_state, popsize=500)\n",
    "\n",
    "    # Compute the fitnesses\n",
    "    fitnesses = rastrigin(population)\n",
    "\n",
    "    # Inform the evolutionary algorithm of the fitnesses and get its next state\n",
    "    cem_state = cem_tell(cem_state, population, fitnesses)\n",
    "\n",
    "    # If it is time to report, print the status\n",
    "    tnow = datetime.now()\n",
    "    if ((tnow - last_report_time).total_seconds() > report_interval) or (generation == num_generations):\n",
    "        print(\"generation:\", generation, \"mean fitnesses:\", torch.mean(fitnesses, dim=-1))\n",
    "        last_report_time = tnow\n",
    "\n",
    "print(\"time taken: \", (last_report_time - start_time).total_seconds(), \"secs\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14",
   "metadata": {},
   "source": [
    "If this notebook is executed on a GPU, the batched search above should take less time than `batch_size` times the duration of the single search, particularly for larger values of `batch_size`. Here are the center points found by CEM, one row per search:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "cem_state.center"
   ]
  },
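  {
   "cell_type": "markdown",
   "id": "15a",
   "metadata": {},
   "source": [
    "Since each row of the center tensor belongs to a different search, one simple way to compare the searches is to evaluate the fitness of each center and pick the lowest one. The cell below is a minimal sketch of that idea. The helper `best_search()` and the tensor `example_centers` are hypothetical names introduced for this illustration, and `example_centers` is only a stand-in with the same layout as `cem_state.center`; in practice one would pass `cem_state.center` instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from math import pi\n",
    "\n",
    "\n",
    "def rastrigin_batched(x: torch.Tensor) -> torch.Tensor:\n",
    "    # Batch-aware Rastrigin: reduces over the last dimension only\n",
    "    n = x.shape[-1]\n",
    "    A = 10.0\n",
    "    return A * n + torch.sum((x ** 2) - (A * torch.cos(2 * pi * x)), dim=-1)\n",
    "\n",
    "\n",
    "def best_search(centers: torch.Tensor) -> tuple:\n",
    "    # Hypothetical helper: `centers` has shape (batch_size, solution_length).\n",
    "    # Returns the index of the search whose center has the lowest fitness,\n",
    "    # together with the fitness of every center.\n",
    "    center_fitnesses = rastrigin_batched(centers)\n",
    "    return int(torch.argmin(center_fitnesses)), center_fitnesses\n",
    "\n",
    "\n",
    "# Stand-in for `cem_state.center`: 4 made-up centers of length 1000, where\n",
    "# search 2 is placed exactly at the global optimum (the origin)\n",
    "example_centers = torch.randn(4, 1000)\n",
    "example_centers[2] = 0.0\n",
    "\n",
    "best_index, center_fitnesses = best_search(example_centers)\n",
    "print(\"fitness of each center:\", center_fitnesses)\n",
    "print(\"best search:\", best_index)"
   ]
  },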
  {
   "cell_type": "markdown",
   "id": "16",
   "metadata": {},
   "source": [
    "## Another example\n",
    "\n",
    "As another example, let us consider the functional `pgpe` algorithm.\n",
    "For `pgpe`, `center_learning_rate` is a hyperparameter which is expected as a scalar in the non-batched case.\n",
    "If it is provided as a vector, this means that for each batch item `i`, the `i`-th value of the `center_learning_rate` vector will be used.\n",
    "\n",
    "Let us build a `center_learning_rate` vector:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "center_learning_rates = torch.linspace(0.001, 0.4, batch_size)\n",
    "center_learning_rates"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "18",
   "metadata": {},
   "source": [
    "Now we prepare the first state of our `pgpe` search:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19",
   "metadata": {},
   "outputs": [],
   "source": [
    "pgpe_state = pgpe(\n",
    "    # We want to minimize the evaluation results.\n",
    "    objective_sense=\"min\",\n",
    "\n",
    "    # The batch of vectors `center_inits` is given as our `center_init`,\n",
    "    # that is, the center point(s) of the initial search distribution(s).\n",
    "    center_init=center_inits.to(device),\n",
    "\n",
    "    # Standard deviation for the initial search distribution(s):\n",
    "    stdev_init=10.0,\n",
    "\n",
    "    # We provide our `center_learning_rate` batch here:\n",
    "    center_learning_rate=center_learning_rates,\n",
    "\n",
    "    # Learning rate for the standard deviation(s) of the search distribution(s):\n",
    "    stdev_learning_rate=0.1,\n",
    "\n",
    "    # We use the \"centered\" ranking where the worst solution is ranked -0.5,\n",
    "    # and the best solution is ranked +0.5:\n",
    "    ranking_method=\"centered\",\n",
    "\n",
    "    # We use the ClipUp optimizer.\n",
    "    optimizer=\"clipup\",\n",
    "\n",
    "    # Just like how we provide a batch of `center_learning_rate` values,\n",
    "    # we provide a batch of `max_speed` values for ClipUp:\n",
    "    optimizer_config={\"max_speed\": center_learning_rates * 2},\n",
    "\n",
    "    # Maximum relative change allowed for standard deviation(s) of the\n",
    "    # search distribution(s):\n",
    "    stdev_max_change=0.2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20",
   "metadata": {},
   "source": [
    "Below is the main loop of the evolutionary search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We will run the evolutionary search for this many generations:\n",
    "num_generations = 1500\n",
    "\n",
    "start_time = last_report_time = datetime.now()\n",
    "\n",
    "for generation in range(1, 1 + num_generations):\n",
    "    # Get a population from the evolutionary algorithm\n",
    "    population = pgpe_ask(pgpe_state, popsize=500)\n",
    "\n",
    "    # Compute the fitnesses\n",
    "    fitnesses = rastrigin(population)\n",
    "\n",
    "    # Inform the evolutionary algorithm of the fitnesses and get its next state\n",
    "    pgpe_state = pgpe_tell(pgpe_state, population, fitnesses)\n",
    "\n",
    "    # If it is time to report, print the status\n",
    "    tnow = datetime.now()\n",
    "    if ((tnow - last_report_time).total_seconds() > report_interval) or (generation == num_generations):\n",
    "        print(\"generation:\", generation, \"mean fitnesses:\", torch.mean(fitnesses, dim=-1))\n",
    "        last_report_time = tnow\n",
    "\n",
    "print(\"time taken: \", (last_report_time - start_time).total_seconds(), \"secs\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22",
   "metadata": {},
   "source": [
    "Here are the center points found by `pgpe`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23",
   "metadata": {},
   "outputs": [],
   "source": [
    "pgpe_state.optimizer_state.center"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
